#!/usr/bin/env python3
# coding=utf8
# Copyright  2022  Zhenxiang MA, Jiayu DU (SpeechColab)

import sys, os
import argparse
from typing import Iterable
import json, csv
import logging
logging.basicConfig(stream=sys.stderr, level=logging.INFO, format='[%(levelname)s] %(message)s')

import pynini
from pynini.lib import pynutil

# reference: https://github.com/kylebgorman/pynini/blob/master/pynini/lib/edit_transducer.py
# to import original lib:
#     from pynini.lib.edit_transducer import EditTransducer
class EditTransducer:
    """Factored edit transducer over a token vocabulary.

    Adapted from pynini's ``pynini.lib.edit_transducer.EditTransducer`` (see the
    reference link above this class). The transducer is kept as two factors:

    * ``self._e_i`` (left factor): maps vocabulary tokens either to themselves
      or to one of the edit markers ``<insert>`` / ``<delete>`` /
      ``<substitute>``; every edit arc carries half of its configured cost.
    * ``self._e_o`` (right factor): the inverse of the left factor with the
      ``<insert>`` and ``<delete>`` labels swapped on its input side, extended
      with weight-free arcs mapping a token ``x`` to its auxiliary form ``x#``.

    Args:
        symbol_table: symbol table used as ``token_type`` when compiling
            vocabulary tokens; it is expected to contain every token in
            ``vocab`` as well as its auxiliary form ``token + '#'``.
        vocab: the tokens the transducer operates on.
        insert_cost: cost of an insertion.
        delete_cost: cost of a deletion.
        substitute_cost: cost of a substitution.
        bound: if non-zero, the left factor allows at most ``bound`` edits;
            if zero, the number of edits is unbounded.
    """
    DELETE = "<delete>"
    INSERT = "<insert>"
    SUBSTITUTE = "<substitute>"

    def __init__(self,
        symbol_table,
        vocab: Iterable[str],
        insert_cost: float = 1.0,
        delete_cost: float = 1.0,
        substitute_cost: float = 1.0,
        bound: int = 0,
    ) :
        # `vocab` is iterated twice below (once for sigma, once for the
        # auxiliary-form arcs); materialize it so that a one-shot iterable
        # (e.g. a generator) is not silently exhausted after the first pass.
        vocab = list(vocab)

        # sigma accepts any single token of the vocabulary.
        sigma = pynini.union(
            *[ pynini.accep(token, token_type = symbol_table) for token in vocab ],
        ).optimize()

        # Left factor; note that we divide the edit costs by two because they also
        # will be incurred when traversing the right factor.
        insert = pynutil.insert(f"[{self.INSERT}]", weight=insert_cost / 2)
        delete = pynini.cross(
            sigma,
            pynini.accep(f"[{self.DELETE}]", weight=delete_cost / 2)
        )
        substitute = pynini.cross(
            sigma, pynini.accep(f"[{self.SUBSTITUTE}]", weight=substitute_cost / 2)
        )

        edit = pynini.union(insert, delete, substitute).optimize()

        if bound:
            # Bounded variant: sigma* followed by `bound` repetitions of
            # (optional edit, sigma*).
            sigma_star = pynini.closure(sigma)
            self._e_i = sigma_star.copy()
            for _ in range(bound):
                self._e_i.concat(edit.ques).concat(sigma_star)
        else:
            # Unbounded variant: any sequence of edits and plain tokens.
            self._e_i = edit.union(sigma).closure()

        self._e_i.optimize()

        right_factor_std = EditTransducer._right_factor(self._e_i)
        # right_factor_ext allows 0-cost matching between token's raw form & auxiliary form
        # e.g.: 'I' -> 'I#', 'AM' -> 'AM#'
        right_factor_ext = pynini.union(*[
            pynini.cross(
                pynini.accep(x,     token_type = symbol_table),
                pynini.accep(x+'#', token_type = symbol_table),
            ) for x in vocab
        ]).optimize().closure()
        self._e_o = pynini.union(right_factor_std, right_factor_ext).closure().optimize()

    @staticmethod
    def _right_factor(ifst: pynini.Fst) -> pynini.Fst:
        """Derive the right factor from the left factor ``ifst``.

        The left factor is inverted, then the ``<insert>`` and ``<delete>``
        labels are swapped on the input side of the inverted FST.
        """
        ofst = pynini.invert(ifst)
        # The edit markers were compiled from bracketed strings, so their labels
        # are looked up in pynini's generated-symbols table.
        syms = pynini.generated_symbols()
        insert_label = syms.find(EditTransducer.INSERT)
        delete_label = syms.find(EditTransducer.DELETE)
        pairs = [(insert_label, delete_label), (delete_label, insert_label)]
        right_factor = ofst.relabel_pairs(ipairs=pairs)
        return right_factor

    def create_lattice(self, iexpr: pynini.FstLike, oexpr: pynini.FstLike) -> pynini.Fst:
        """Compose ``iexpr`` and ``oexpr`` with the two factors.

        Returns:
            The lattice ``(iexpr @ left factor) @ (right factor @ oexpr)``.

        Raises:
            RuntimeError: if the resulting lattice is empty.
        """
        lattice = (iexpr @ self._e_i) @ (self._e_o @ oexpr)
        EditTransducer.check_wellformed_lattice(lattice)
        return lattice

    @staticmethod
    def check_wellformed_lattice(lattice: pynini.Fst) -> None:
        """Raise ``RuntimeError`` if ``lattice`` has no start state."""
        if lattice.start() == pynini.NO_STATE_ID:
            raise RuntimeError("Edit distance composition lattice is empty.")

    def compute_distance(self, iexpr: pynini.FstLike, oexpr: pynini.FstLike) -> float:
        """Return the cost of the cheapest path aligning ``iexpr`` with ``oexpr``.

        Raises:
            RuntimeError: if the composition lattice is empty.
        """
        lattice = self.create_lattice(iexpr, oexpr)
        # The shortest cost from all final states to the start state is
        # equivalent to the cost of the shortest path.
        start = lattice.start()
        return float(pynini.shortestdistance(lattice, reverse=True)[start])

    def compute_alignment(self, iexpr: pynini.FstLike, oexpr: pynini.FstLike) -> pynini.FstLike:
        """Return the single shortest path through the lattice, optimized.

        Raises:
            RuntimeError: if the composition lattice is empty.
        """
        lattice = self.create_lattice(iexpr, oexpr)
        alignment = pynini.shortestpath(lattice, nshortest=1, unique=True)
        return alignment.optimize()


class ErrorStats:
    """Corpus-level counters and derived error rates, with text renderers."""

    def __init__(self):
        # NOTE: to_json() dumps the instance dict, so the order in which the
        # attributes are created here is the key order of the JSON output.

        # utterance bookkeeping
        self.num_ref_utts = 0
        self.num_hyp_utts = 0
        self.num_eval_utts = 0 # in both ref & hyp
        self.num_hyp_without_ref = 0

        # token-level counts: Correct / Substitution / Insertion / Deletion
        self.C = 0
        self.S = 0
        self.I = 0
        self.D = 0
        self.token_error_rate = 0.0
        self.modified_token_error_rate = 0.0

        # sentence-level counts
        self.num_utts_with_error = 0
        self.sentence_error_rate = 0.0

    def to_json(self):
        """Serialize every attribute of this object as a one-line JSON object."""
        return json.dumps(vars(self))

    def to_kaldi(self):
        """Render the statistics as two '%WER' / '%SER' lines."""
        num_edits = self.S + self.D + self.I
        num_ref_tokens = self.C + self.S + self.D
        wer_line = (
            f'%WER {self.token_error_rate:.2f} '
            f'[ {num_edits} / {num_ref_tokens}, {self.I} ins, {self.D} del, {self.S} sub ]'
        )
        ser_line = (
            f'%SER {self.sentence_error_rate:.2f} '
            f'[ {self.num_utts_with_error} / {self.num_eval_utts} ]'
        )
        return wer_line + '\n' + ser_line + '\n'

    def to_summary(self):
        """Render a multi-line, human-readable report of all statistics."""
        num_ref_tokens = self.C + self.S + self.D
        num_edits = self.S + self.I + self.D
        lines = [
            '==================== Overall Statistics ====================',
            f'num_ref_utts: {self.num_ref_utts}',
            f'num_hyp_utts: {self.num_hyp_utts}',
            f'num_hyp_without_ref: {self.num_hyp_without_ref}',
            f'num_eval_utts: {self.num_eval_utts}',
            f'sentence_error_rate: {self.sentence_error_rate:.2f}%',
            f'token_error_rate: {self.token_error_rate:.2f}%',
            f'modified_token_error_rate: {self.modified_token_error_rate:.2f}%',
            'token_stats:',
            f'  - tokens:{num_ref_tokens:>7}',
            f'  - edits: {num_edits:>7}',
            f'  - cor:   {self.C:>7}',
            f'  - sub:   {self.S:>7}',
            f'  - ins:   {self.I:>7}',
            f'  - del:   {self.D:>7}',
            '============================================================',
        ]
        # every line, including the last one, is newline-terminated
        return '\n'.join(lines) + '\n'


class Utterance:
    """Plain record pairing an utterance id with its text."""

    def __init__(self, uid, text):
        self.uid, self.text = uid, text


def LoadKaldiArc(filepath):
    """Read a '<uid> <text>' file into a dict mapping uid -> Utterance.

    Each non-blank line is split at its first run of whitespace: the first
    field is the utterance id, the remainder (possibly empty) is the text.

    Raises:
        RuntimeError: if the same utterance id appears more than once.
    """
    table = {}
    with open(filepath, 'r', encoding='utf8') as arc_file:
        for raw_line in arc_file:
            entry = raw_line.strip()
            if not entry:
                continue  # blank lines carry no utterance

            # a non-empty stripped line always yields one or two fields
            uid, *rest = entry.split(maxsplit=1)
            text = rest[0] if rest else ''

            if uid in table:
                raise RuntimeError(f'Found duplicated utterence id {uid}')
            table[uid] = Utterance(uid, text)
    return table


def BreakHyphen(token : str):
    """Return the extra vocabulary entries implied by a hyphenated token.

    For 'T-SHIRT' the result is the hyphen-separated pieces followed by the
    token with all hyphens removed: ['T', 'SHIRT', 'TSHIRT'].
    The token must contain at least one '-'.
    """
    assert '-' in token
    pieces = token.split('-')
    # joining the pieces without a separator equals removing every hyphen
    return [*pieces, ''.join(pieces)]


def LoadGLM(rel_path):
    '''
    Load GLM rules from a comma-separated file.

    rel_path is resolved against the directory containing this script; an
    absolute path is used as-is.

    Each row of the file is one rule: a list of phrases treated as
    alternatives of each other. Surrounding whitespace of every phrase is
    stripped. Rules are named after their 0-based row index.

    glm.csv:
        I'VE,I HAVE
        GOING TO,GONNA
        ...
        T-SHIRT,T SHIRT,TSHIRT

    glm:
        {
            '<RULE_000000>': ["I'VE", 'I HAVE'],
            '<RULE_000001>': ['GOING TO', 'GONNA'],
            ...
            '<RULE_999999>': ['T-SHIRT', 'T SHIRT', 'TSHIRT'],
        }
    '''
    logging.info(f'Loading GLM from {rel_path} ...')

    # os.path.join() discards the script directory when rel_path is already
    # absolute, so both relative and absolute paths work.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    abs_path = os.path.join(script_dir, rel_path)

    # Read all rows up front inside a context manager so the file handle is
    # released even if parsing fails.
    with open(abs_path, encoding="utf-8") as glm_file:
        rows = list(csv.reader(glm_file, delimiter=','))

    glm = {}
    for k, rule in enumerate(rows):
        rule_name = f'<RULE_{k:06d}>'
        glm[rule_name] = [ phrase.strip() for phrase in rule ]
    logging.info(f'  #rule: {len(glm)}')

    return glm


def SymbolEQ(symbol_table, i1, i2):
    """Tell whether two symbol ids denote the same token.

    Two symbols match when they are identical, or when one is the auxiliary
    form of the other, i.e. the other symbol followed by exactly one '#'
    ('I' matches 'I#').

    Only a single trailing '#' is considered: stripping every '#' from both
    ends would also equate unrelated symbols such as '#A' and 'A', or
    'A##' and 'A'.

    Args:
        symbol_table: object whose ``find(id)`` returns the symbol string.
        i1: id of the first symbol.
        i2: id of the second symbol.
    """
    s1 = symbol_table.find(i1)
    s2 = symbol_table.find(i2)
    return s1 == s2 or s1 + '#' == s2 or s2 + '#' == s1


def PrintSymbolTable(symbol_table: pynini.SymbolTable):
    """Debug helper: print every (id, symbol) pair of the table to stdout."""
    print('SYMBOL_TABLE:')
    num_symbols = symbol_table.num_symbols()
    for idx in range(num_symbols):
        symbol = symbol_table.find(idx)
        # find() works in both directions (id -> symbol and symbol -> id);
        # check that the round trip lands on the same id
        assert symbol_table.find(symbol) == idx
        print(idx, symbol)
    print()


def BuildSymbolTable(vocab) -> pynini.SymbolTable :
    """Create a symbol table holding '<epsilon>' followed by every entry of vocab."""
    logging.info('Building symbol table ...')
    table = pynini.SymbolTable()

    # '<epsilon>' is added before any vocabulary entry
    for symbol in ['<epsilon>', *vocab]:
        table.add_symbol(symbol)
    logging.info(f'  #symbols: {table.num_symbols()}')

    # debugging aids:
    #PrintSymbolTable(table)
    #table.write_text('symbol_table.txt')
    return table


def BuildGLMTagger(glm, symbol_table) -> pynini.Fst:
    """Build an FST that wraps every GLM phrase occurrence in its rule tag.

    For a rule '<RULE_x>' containing phrase P, the tagger rewrites
    'P' into '<RULE_x> P <RULE_x>'.

    Args:
        glm: dict mapping rule tag -> list of phrases.
        symbol_table: symbol table used as token type for tags and phrases.
    """
    logging.info('Building GLM tagger ...')

    def wrap_with_tag(tag, phrase):
        # insert(tag) + phrase + insert(tag)
        return (
            pynutil.insert(pynini.accep(tag, token_type = symbol_table)) +
            pynini.accep(phrase, token_type = symbol_table) +
            pynutil.insert(pynini.accep(tag, token_type = symbol_table))
        )

    phrase_taggers = [
        wrap_with_tag(tag, phrase)
        for tag, phrases in glm.items()
        for phrase in phrases
    ]

    # every symbol of the table except the one with id 0 (epsilon)
    non_epsilon = [
        pynini.accep(sym, token_type = symbol_table)
        for idx, sym in symbol_table
        if idx != 0
    ]
    alphabet = pynini.union(*non_epsilon).optimize()

    # rewrite rule with empty left/right contexts over alphabet*;
    # could be slow with large vocabulary
    any_rule = pynini.union(*phrase_taggers).optimize()
    tagger = pynini.cdrewrite(any_rule, '', '', alphabet.closure()).optimize()
    return tagger


def TokenWidth(token: str):
    """Return the display width of token in terminal columns.

    Characters whose Unicode East Asian Width is Wide ('W') or Fullwidth ('F')
    count as 2 columns, everything else as 1. This covers the CJK ideographs
    as well as other double-width characters such as kana, Hangul and
    fullwidth forms, so that padded columns stay aligned.
    """
    # local import: this is the only user of unicodedata in the file
    import unicodedata

    return sum(
        2 if unicodedata.east_asian_width(c) in ('W', 'F') else 1
        for c in token
    )


def PrintPrettyAlignment(raw_hyp, edit_ali, ref_ali, hyp_ali, stream = sys.stderr):
    """Print an alignment as four rows: raw HYP, then column-aligned HYP#/REF/EDIT.

    Args:
        raw_hyp: the hypothesis text, printed verbatim on the first row.
        edit_ali: per-position edit tags; the tag 'C' is rendered as blank.
        ref_ali: per-position reference tokens.
        hyp_ali: per-position hypothesis tokens.
        stream: file-like object to print to.

    The three alignment lists must have equal length.
    """
    assert len(edit_ali) == len(ref_ali) == len(hyp_ali)

    hyp_cells, ref_cells, edit_cells = [], [], []
    for tag, hyp_tok, ref_tok in zip(edit_ali, hyp_ali, ref_ali):
        shown_tag = '' if tag == 'C' else tag # don't bother printing correct edit-tag

        ref_w = TokenWidth(ref_tok)
        hyp_w = TokenWidth(hyp_tok)
        tag_w = TokenWidth(shown_tag)
        # widest cell of this column plus one separating blank
        col_w = max(ref_w, hyp_w, tag_w) + 1

        hyp_cells.append(hyp_tok + ' ' * (col_w - hyp_w))
        ref_cells.append(ref_tok + ' ' * (col_w - ref_w))
        edit_cells.append(shown_tag + ' ' * (col_w - tag_w))

    rows = (
        f'  HYP  : {raw_hyp}',
        '  HYP# : ' + ''.join(hyp_cells),
        '  REF  : ' + ''.join(ref_cells),
        '  EDIT : ' + ''.join(edit_cells),
    )
    for row in rows:
        print(row, file = stream)


def ComputeTokenErrorRate(c, s, i, d):
    """Return (TER, mTER) in percent from cor/sub/ins/del counts.

    TER  = 100 * (s + d + i) / reference length
    mTER = 100 * (s + d + i) / max(reference length, hypothesis length)

    where reference length is c + s + d and hypothesis length is c + s + i.
    The reference length must be non-zero.
    """
    ref_len = c + s + d
    assert ref_len != 0
    hyp_len = c + s + i
    num_edits = s + d + i

    ter = 100.0 * num_edits / ref_len
    mter = 100.0 * num_edits / max(ref_len, hyp_len)
    return ter, mter


def ComputeSentenceErrorRate(num_err_utts, num_utts):
    """Return the percentage of utterances with errors; num_utts must be non-zero."""
    assert num_utts != 0
    ser = 100.0 * num_err_utts / num_utts
    return ser


# Script entry point: load REF/HYP transcripts and GLM rules, align every
# hypothesis (expanded with GLM alternatives) against its reference, write
# per-utterance details to `result_file`, and print corpus-level statistics.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--logk', type = int, default = 500, help = 'logging interval')
    parser.add_argument('--tokenizer', choices = ['whitespace', 'char'], default = 'whitespace', help = 'whitespace for WER, char for CER')
    parser.add_argument('--glm', type = str, default = 'glm_en.csv', help = 'glm')
    parser.add_argument('--ref', type = str, required = True, help = 'reference kaldi arc file')
    parser.add_argument('--hyp', type = str, required = True, help = 'hypothesis kaldi arc file')
    parser.add_argument('result_file', type = str)
    args = parser.parse_args()
    logging.info(args)

    stats = ErrorStats()

    # Pick the function that turns a text line into a list of tokens.
    logging.info('Generating tokenizer ...')
    if args.tokenizer == 'whitespace':
        def word_tokenizer(text):
            return text.strip().split()
        tokenizer = word_tokenizer
    elif args.tokenizer == 'char':
        # every non-space character becomes its own token
        def char_tokenizer(text):
            return [ c for c in text.strip().replace(' ', '') ]
        tokenizer = char_tokenizer
    else:
        # not reachable through argparse: `choices` restricts --tokenizer to the two values above
        tokenizer = None
    assert(tokenizer)


    logging.info('Loading REF & HYP ...')
    ref_utts = LoadKaldiArc(args.ref)
    hyp_utts = LoadKaldiArc(args.hyp)

    # check valid utterances in hyp that have matched non-empty reference
    uids = []
    for uid in sorted(hyp_utts.keys()):
        if uid in ref_utts.keys():
            if ref_utts[uid].text.strip(): # non-empty reference
                uids.append(uid)
            else:
                logging.warning(F'Found {uid} with empty reference, skipping...')
        else:
            logging.warning(F'Found {uid} without reference, skipping...')
            stats.num_hyp_without_ref += 1

    stats.num_hyp_utts  = len(hyp_utts)
    stats.num_ref_utts  = len(ref_utts)
    stats.num_eval_utts = len(uids)
    logging.info(f'  #hyp:{stats.num_hyp_utts}, #ref:{stats.num_ref_utts}, #utts_to_evaluate:{stats.num_eval_utts}')

    # Vocabulary part 1: every token of the evaluated REF/HYP pairs, plus the
    # extra forms derived from hyphenated tokens (see BreakHyphen).
    tokens = []
    for uid in uids:
        ref_tokens = tokenizer(ref_utts[uid].text)
        hyp_tokens = tokenizer(hyp_utts[uid].text)
        for t in ref_tokens + hyp_tokens:
            tokens.append(t)
            if '-' in t:
                tokens.extend(BreakHyphen(t))
    vocab_from_utts = list(set(tokens))
    logging.info(f'  HYP&REF vocab size: {len(vocab_from_utts)}')


    assert(args.glm)
    glm = LoadGLM(args.glm)

    # Vocabulary part 2: every token of every GLM phrase, handled the same way.
    tokens = []
    for rule in glm.values():
        for phrase in rule:
            for t in tokenizer(phrase):
                tokens.append(t)
                if '-' in t:
                    tokens.extend(BreakHyphen(t))
    vocab_from_glm = list(set(tokens))
    logging.info(f'  GLM vocab size: {len(vocab_from_glm)}')

    vocab = list(set(vocab_from_utts + vocab_from_glm))
    logging.info(f'Global vocab size: {len(vocab)}')

    symtab = BuildSymbolTable(
        # Normal evaluation vocab + auxiliary form for alternative paths + GLM tags
        vocab + [ x+'#' for x in vocab ] + [ x for x in glm.keys() ]
    )
    glm_tagger = BuildGLMTagger(glm, symtab)
    edit_transducer = EditTransducer(symbol_table = symtab, vocab = vocab)

    logging.info('Evaluating error rate ...')
    # NOTE(review): `fo` is not managed by a `with` block; it is closed only by
    # the explicit fo.close() at the very end, i.e. not when an exception is
    # raised in between.
    fo = open(args.result_file, 'w+', encoding='utf8')
    ndone = 0
    for uid in uids:
        ref = ref_utts[uid].text
        raw_hyp = hyp_utts[uid].text

        # Compile REF and raw HYP as acceptors over the symbol table
        # (tokens are re-joined with single spaces).
        ref_fst = pynini.accep(' '.join(tokenizer(ref)), token_type = symtab)
        #print(ref_fst.string(token_type = symtab))

        raw_hyp_fst = pynini.accep(' '.join(tokenizer(raw_hyp)), token_type = symtab)
        #print(raw_hyp_fst.string(token_type = symtab))

        # Say, we have:
        #   RULE_001: "I'M" <-> "I AM"
        #   REF: HEY I AM HERE
        #   HYP: HEY I'M HERE
        #
        # We want to expand HYP with GLM rules (marked with auxiliary #)
        #   HYP#: HEY {I'M | I# AM#} HERE
        # REF is honored to keep its original form.
        #
        # This could be considered as a flexible on-the-fly TN towards HYP.


        # 1. GLM rule tagging:
        #   HEY I'M HERE
        # ->
        #   HEY <RULE_001> I'M <RULE_001> HERE
        lattice = (raw_hyp_fst @ glm_tagger).optimize()
        tagged_ir = pynini.shortestpath(lattice, nshortest=1, unique=True).string(token_type = symtab)
        #print(hyp_tagged)

        # 2. GLM rule expansion:
        #   HEY <RULE_001> I'M <RULE_001> HERE
        # ->
        #   sausage-like fst: HEY {I'M | I# AM#} HERE
        tokens = tagged_ir.split()
        sausage = pynini.accep('', token_type = symtab)
        i = 0
        while i < len(tokens): # invariant: tokens[0, i) has been built into fst
            forms = []
            if tokens[i].startswith('<RULE_') and tokens[i].endswith('>'):  # rule segment
                rule_name = tokens[i]
                rule = glm[rule_name]
                # pre-condition: i -> ltag
                # scan forward for the matching closing tag; everything in
                # between is the phrase as it appears in HYP
                raw_form = ''
                for j in range(i+1, len(tokens)):
                    if tokens[j] == rule_name:
                        raw_form = ' '.join(tokens[i+1: j])
                        break
                # fails when no closing tag was found or the tagged span is empty
                assert(raw_form)
                # post-condition: i -> ltag, j -> rtag

                # the raw form stays as-is; every other phrase of the rule is
                # added as an alternative in auxiliary ('#'-suffixed) form
                # NOTE(review): phrase.split() splits on whitespace regardless of
                # --tokenizer, while the vocabulary was built with tokenizer(phrase);
                # with the 'char' tokenizer multi-character words here would
                # presumably be missing from symtab -- confirm.
                forms.append(raw_form)
                for phrase in rule:
                    if phrase != raw_form:
                        forms.append(' '.join([ x+'#' for x in phrase.split() ]))
                # continue right after the closing tag
                i = j + 1
            else:  # normal token segment
                token = tokens[i]
                forms.append(token)
                if "-" in token:  # token with hyphen yields extra forms
                    forms.append(' '.join([ x+'#' for x in token.split('-') ])) # 'T-SHIRT' -> 'T# SHIRT#'
                    forms.append(token.replace('-', '') + '#') # 'T-SHIRT' -> 'TSHIRT#'
                i += 1

            # one sausage segment = union of all alternative forms of this span
            sausage_segment = pynini.union(
                *[ pynini.accep(x, token_type=symtab) for x in forms ]
            ).optimize()
            sausage += sausage_segment
        hyp_fst = sausage.optimize()

        # Utterance-Level error rate evaluation
        alignment = edit_transducer.compute_alignment(ref_fst, hyp_fst)

        # Classify every arc of the alignment: the input label is the REF side,
        # the output label the HYP side, and label 0 stands for "no token".
        # NOTE(review): the printed alignment follows the order in which states
        # are iterated; presumably that matches the order along the path -- confirm.
        distance = 0.0
        C, S, I, D = 0, 0, 0, 0  # Cor, Sub, Ins, Del
        edit_ali, ref_ali, hyp_ali = [], [], []
        for state in alignment.states():
            for arc in alignment.arcs(state):
                i, o = arc.ilabel, arc.olabel
                if i != 0 and o != 0 and SymbolEQ(symtab, i, o):
                    # same token on both sides (raw vs. auxiliary form counts as same)
                    e = 'C'
                    r, h = symtab.find(i), symtab.find(o)

                    C += 1
                    distance += 0.0
                elif i != 0 and o != 0 and not SymbolEQ(symtab, i, o):
                    e = 'S'
                    r, h = symtab.find(i), symtab.find(o)

                    S += 1
                    distance += 1.0
                elif i == 0 and o != 0:
                    # token only on the HYP side
                    e = 'I'
                    r, h = '*', symtab.find(o)

                    I += 1
                    distance += 1.0
                elif i != 0 and o == 0:
                    # token only on the REF side
                    e = 'D'
                    r, h = symtab.find(i), '*'

                    D += 1
                    distance += 1.0
                else:
                    # arc with label 0 on both sides
                    raise RuntimeError

                edit_ali.append(e)
                ref_ali.append(r)
                hyp_ali.append(h)
        #assert(distance == edit_transducer.compute_distance(ref_fst, sausage))

        utt_ter, utt_mter = ComputeTokenErrorRate(C, S, I, D)
        # NOTE(review): uid is interpolated without quotes, so this record is not
        # valid JSON unless uid happens to be numeric -- confirm what downstream
        # consumers expect before changing it.
        print(F'{{"uid":{uid}, "score":{-distance}, "TER":{utt_ter:.2f}, "mTER":{utt_mter:.2f}, "cor":{C}, "sub":{S}, "ins":{I}, "del":{D}}}', file=fo)
        PrintPrettyAlignment(raw_hyp, edit_ali, ref_ali, hyp_ali, fo)

        # any edit at all makes the utterance count as erroneous
        if utt_ter > 0:
            stats.num_utts_with_error += 1

        stats.C += C
        stats.S += S
        stats.I += I
        stats.D += D

        ndone += 1
        if ndone % args.logk == 0:
            logging.info(f'{ndone} utts evaluated.')
    logging.info(f'{ndone} utts evaluated in total.')

    # Corpus-Level evaluation
    stats.token_error_rate, stats.modified_token_error_rate = ComputeTokenErrorRate(stats.C, stats.S, stats.I, stats.D)
    stats.sentence_error_rate = ComputeSentenceErrorRate(stats.num_utts_with_error, stats.num_eval_utts)

    # JSON and kaldi-style summaries go to stdout, the readable report to the result file
    print(stats.to_json())
    print(stats.to_kaldi())
    print(stats.to_summary(), file=fo)

    fo.close()
